Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dropout #101

Draft
wants to merge 16 commits into
base: main_perf
Choose a base branch
from
Draft

Dropout #101

wants to merge 16 commits into from

Conversation

micmelesse
Copy link
Collaborator

No description provided.

alexkranias-amd and others added 2 commits December 2, 2024 20:26
This is a combination of 11 commits.

save

fix: dropout=0.0 woorks

feat: dropout restrictions removed. failing tests

test: reduced tests to simple cases

test: failure is due to query + key padding mask NOT varlen itself

feat: varlen dropout fwd passes

fix: varlen bwd dropout works!

test: discovered  bwd error for non-dropout cases for large seqlen

save

save

use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes
This is a combination of 63 commits.

pick test case

save philox offsets into metadata

pass offset to ref

common dropout mask

simple droput out mask

start dropout ref. work on returning SD_Mask next with negative numbers

refernce is working

dropout bwd ref faling case

transfer rng_state properly

save changes

one dropout mask function

save

save

minizmize diff

save

use torch.where in backward

save

save

save

dk works!

passes

reference is working. TODO" attn_ref is broken

varlen ref working

attn failing case

with ones. attn_ref matches. fails with randn. we are seeing failure with large sizes from dv.

save

skip attn matrices

compare the masks and find failing case

rm cdiv_fn

put dropout and alibi in common

save

compare masks

save

save

pytorch ref is using tiles

save

save

tl_rand_ref

cache ref dropout mask

new generate_dropout_mask_ref using tiling

issolate failing varlen case

simple dropout

loop on k

print rng_outputs

save

fwd kernel works

save

dv passed

close to dk

simple ref

save

seperate droped and scaled in ref and triton kernel

ref changes

working delta with dp

find failing dv failures

find failing case due to delta

save

delta from dp working

bwd impl green

enable test fwd

save

save

delete kernels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants